{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ff2b6513",
   "metadata": {
    "origin_pos": 1
   },
   "source": [
    "# Utility Functions and Classes\n",
    ":label:`sec_utils`\n",
    "\n",
    "\n",
    "This section contains the implementations of utility functions and classes used in this book.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3d3fe5bb",
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "1"
    },
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:10.024196Z",
     "iopub.status.busy": "2023-08-18T19:28:10.023668Z",
     "iopub.status.idle": "2023-08-18T19:28:12.807027Z",
     "shell.execute_reply": "2023-08-18T19:28:12.805588Z"
    },
    "origin_pos": 3,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "import inspect\n",
    "from IPython import display\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1296db3f",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "Hyperparameters.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "98bb0ea8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.814437Z",
     "iopub.status.busy": "2023-08-18T19:28:12.813801Z",
     "iopub.status.idle": "2023-08-18T19:28:12.820512Z",
     "shell.execute_reply": "2023-08-18T19:28:12.819348Z"
    },
    "origin_pos": 7,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "@d2l.add_to_class(d2l.HyperParameters)  #@save\n",
    "def save_hyperparameters(self, ignore=[]):\n",
    "    \"\"\"Save the caller's arguments as attributes of `self`.\n",
    "\n",
    "    Reads the local variables of the calling frame, drops `self`, every\n",
    "    name listed in `ignore` and every name starting with an underscore,\n",
    "    stores the rest in `self.hparams` and mirrors each entry as an\n",
    "    attribute on `self`.\n",
    "\n",
    "    Args:\n",
    "        ignore: Iterable of local-variable names to leave out. The default\n",
    "            list is only read, never mutated.\n",
    "    \"\"\"\n",
    "    frame = inspect.currentframe().f_back\n",
    "    _, _, _, local_vars = inspect.getargvalues(frame)\n",
    "    # Build the exclusion set once instead of once per local variable; going\n",
    "    # through set() also lets callers pass a tuple or set, not only a list.\n",
    "    skipped = set(ignore) | {'self'}\n",
    "    self.hparams = {k: v for k, v in local_vars.items()\n",
    "                    if k not in skipped and not k.startswith('_')}\n",
    "    for k, v in self.hparams.items():\n",
    "        setattr(self, k, v)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "baddd421",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "Progress bar.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "885db4c2",
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "22"
    },
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.824277Z",
     "iopub.status.busy": "2023-08-18T19:28:12.823984Z",
     "iopub.status.idle": "2023-08-18T19:28:12.835713Z",
     "shell.execute_reply": "2023-08-18T19:28:12.834666Z"
    },
    "origin_pos": 9,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "@d2l.add_to_class(d2l.ProgressBoard)  #@save\n",
    "def draw(self, x, y, label, every_n=1):\n",
    "    \"\"\"Record one (x, y) sample for `label` and refresh the plot.\n",
    "\n",
    "    Samples are buffered per label. Once `every_n` samples are buffered,\n",
    "    their mean becomes one plotted point and the buffer is emptied.\n",
    "    Nothing is drawn when `self.display` is falsy.\n",
    "    \"\"\"\n",
    "    Point = collections.namedtuple('Point', ['x', 'y'])\n",
    "    if not hasattr(self, 'raw_points'):\n",
    "        self.raw_points = collections.OrderedDict()\n",
    "        self.data = collections.OrderedDict()\n",
    "    if label not in self.raw_points:\n",
    "        self.raw_points[label] = []\n",
    "        self.data[label] = []\n",
    "    buffer, curve = self.raw_points[label], self.data[label]\n",
    "    buffer.append(Point(x, y))\n",
    "    if len(buffer) != every_n:\n",
    "        return\n",
    "    # Collapse the buffered samples into their average point\n",
    "    count = len(buffer)\n",
    "    curve.append(Point(sum(p.x for p in buffer) / count,\n",
    "                       sum(p.y for p in buffer) / count))\n",
    "    buffer.clear()\n",
    "    if not self.display:\n",
    "        return\n",
    "    d2l.use_svg_display()\n",
    "    if self.fig is None:\n",
    "        self.fig = d2l.plt.figure(figsize=self.figsize)\n",
    "    plt_lines, labels = [], []\n",
    "    for (name, pts), ls, color in zip(self.data.items(), self.ls, self.colors):\n",
    "        xs = [p.x for p in pts]\n",
    "        ys = [p.y for p in pts]\n",
    "        handle = d2l.plt.plot(xs, ys, linestyle=ls, color=color)[0]\n",
    "        plt_lines.append(handle)\n",
    "        labels.append(name)\n",
    "    axes = self.axes if self.axes else d2l.plt.gca()\n",
    "    if self.xlim:\n",
    "        axes.set_xlim(self.xlim)\n",
    "    if self.ylim:\n",
    "        axes.set_ylim(self.ylim)\n",
    "    if not self.xlabel:\n",
    "        # Fall back to the name of the x quantity\n",
    "        self.xlabel = self.x\n",
    "    axes.set_xlabel(self.xlabel)\n",
    "    axes.set_ylabel(self.ylabel)\n",
    "    axes.set_xscale(self.xscale)\n",
    "    axes.set_yscale(self.yscale)\n",
    "    axes.legend(plt_lines, labels)\n",
    "    display.display(self.fig)\n",
    "    display.clear_output(wait=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "463a57b9",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "Add the FrozenLake environment.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bd8e6dad",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.840420Z",
     "iopub.status.busy": "2023-08-18T19:28:12.839831Z",
     "iopub.status.idle": "2023-08-18T19:28:12.846955Z",
     "shell.execute_reply": "2023-08-18T19:28:12.846019Z"
    },
    "origin_pos": 11,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def frozen_lake(seed): #@save\n",
    "    \"\"\"Build a non-slippery FrozenLake-v1 environment and summarize its MDP.\n",
    "\n",
    "    Returns a dict with the grid description, the number of states and\n",
    "    actions, the positions of the fields inside a transition tuple, the\n",
    "    transition table keyed by (state, action), and the environment itself.\n",
    "    \"\"\"\n",
    "    # See https://www.gymlibrary.dev/environments/toy_text/frozen_lake/ to learn more about this env\n",
    "    # How to process env.P.items is adapted from https://sites.google.com/view/deep-rl-bootcamp/labs\n",
    "    import gym\n",
    "\n",
    "    env = gym.make('FrozenLake-v1', is_slippery=False)\n",
    "    env.seed(seed)\n",
    "    env.action_space.np_random.seed(seed)\n",
    "    env.action_space.seed(seed)\n",
    "\n",
    "    env_info = {\n",
    "        'desc': env.desc,  # 2D array specifying what each grid item means\n",
    "        'num_states': env.nS,  # Number of observations/states or obs/state dim\n",
    "        'num_actions': env.nA,  # Number of actions or action dim\n",
    "        # Indices into a (transition probability, nextstate, reward, done) tuple\n",
    "        'trans_prob_idx': 0,  # Index of transition probability entry\n",
    "        'nextstate_idx': 1,  # Index of next state entry\n",
    "        'reward_idx': 2,  # Index of reward entry\n",
    "        'done_idx': 3,  # Index of done entry\n",
    "        'mdp': {},\n",
    "        'env': env,\n",
    "    }\n",
    "\n",
    "    # env.P[s] = {a0: [(p(s'|s,a0), s', reward, done), ...], a1: [...], ...}\n",
    "    # Each value is a list such as\n",
    "    # [(0.3, 0, 0, False), (0.3, 0, 0, False), (0.3, 4, 1, False)];\n",
    "    # flatten it so that every (s, a) pair indexes its list directly.\n",
    "    transitions = env_info['mdp']\n",
    "    for s, per_action in env.P.items():\n",
    "        transitions.update(((s, a), pxrds) for a, pxrds in per_action.items())\n",
    "\n",
    "    return env_info"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "472118b3",
   "metadata": {
    "origin_pos": 12
   },
   "source": [
    "Create the environment.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6aa42d02",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.850455Z",
     "iopub.status.busy": "2023-08-18T19:28:12.849911Z",
     "iopub.status.idle": "2023-08-18T19:28:12.854684Z",
     "shell.execute_reply": "2023-08-18T19:28:12.853882Z"
    },
    "origin_pos": 13,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def make_env(name ='', seed=0): #@save\n",
    "    # Input parameters:\n",
    "    # name: specifies a gym environment.\n",
    "    # For Value iteration, only FrozenLake-v1 is supported.\n",
    "    if name == 'FrozenLake-v1':\n",
    "        return frozen_lake(seed)\n",
    "\n",
    "    else:\n",
    "        raise ValueError(\"%s env is not supported in this Notebook\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f8b194e",
   "metadata": {
    "origin_pos": 14
   },
   "source": [
    "Show value function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "164d4c50",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.858698Z",
     "iopub.status.busy": "2023-08-18T19:28:12.858267Z",
     "iopub.status.idle": "2023-08-18T19:28:12.873620Z",
     "shell.execute_reply": "2023-08-18T19:28:12.872821Z"
    },
    "origin_pos": 15,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def show_value_function_progress(env_desc, V, pi): #@save\n",
    "    \"\"\"Visualize how the value function and the policy change over time.\n",
    "\n",
    "    One panel is drawn per iteration: the state values as a grayscale grid,\n",
    "    the map letter of each cell on top, and a red arrow for the chosen\n",
    "    action in every cell that is neither a hole ('H') nor the goal ('G').\n",
    "\n",
    "    Args:\n",
    "        env_desc: 2D array of byte strings describing each grid cell.\n",
    "        V: Values with shape [num_iters, num_states].\n",
    "        pi: Action indices with shape [num_iters, num_states].\n",
    "    \"\"\"\n",
    "    # How to visualize value function is adapted (but changed) from: https://sites.google.com/view/deep-rl-bootcamp/labs\n",
    "\n",
    "    num_rows, num_cols = env_desc.shape\n",
    "    num_panels = V.shape[0]\n",
    "    # Four panels per row. Keep the 4x4 layout for up to 16 panels and add\n",
    "    # rows beyond that instead of failing inside plt.subplot.\n",
    "    panel_rows = max(4, -(-num_panels // 4))\n",
    "    fig, _ = plt.subplots(figsize=(15, 15 * panel_rows / 4))\n",
    "\n",
    "    # Arrow offsets in image coordinates (imshow draws y growing downwards,\n",
    "    # so UP is a negative dy).\n",
    "    # LEFT action: 0, DOWN action: 1\n",
    "    # RIGHT action: 2, UP action: 3\n",
    "    action2dxdy = {0: (-.25, 0), 1: (0, .25),\n",
    "                   2: (.25, 0), 3: (0, -.25)}\n",
    "    # (color, size) of the letter for special cells; others use ('g', 15)\n",
    "    text_style = {'H': ('y', 20), 'G': ('w', 20)}\n",
    "\n",
    "    for k in range(num_panels):\n",
    "        plt.subplot(panel_rows, 4, k + 1)\n",
    "        plt.imshow(V[k].reshape(num_rows, num_cols), cmap='bone')\n",
    "        ax = plt.gca()\n",
    "        # Minor ticks sit on the cell borders and carry the white grid lines\n",
    "        ax.set_xticks(np.arange(0, num_cols + 1) - .5, minor=True)\n",
    "        ax.set_yticks(np.arange(0, num_rows + 1) - .5, minor=True)\n",
    "        ax.grid(which='minor', color='w', linestyle='-', linewidth=3)\n",
    "        ax.tick_params(which='minor', bottom=False, left=False)\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "\n",
    "        policy = pi[k].reshape(num_rows, num_cols)\n",
    "        for y in range(num_rows):\n",
    "            for x in range(num_cols):\n",
    "                cell = env_desc[y, x].decode()\n",
    "                color, size = text_style.get(cell, ('g', 15))\n",
    "                ax.text(x, y, cell, ha='center', va='center', color=color,\n",
    "                        size=size, fontweight='bold')\n",
    "\n",
    "                # No arrow for cells with G and H labels\n",
    "                if cell not in ('G', 'H'):\n",
    "                    dx, dy = action2dxdy[int(policy[y, x])]\n",
    "                    ax.arrow(x, y, dx, dy, color='r', head_width=0.2, head_length=0.15)\n",
    "\n",
    "        ax.set_title('Step = ' + str(k + 1), fontsize=20)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c30d7cc9",
   "metadata": {
    "origin_pos": 16
   },
   "source": [
    "Show Q function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2622a762",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.877031Z",
     "iopub.status.busy": "2023-08-18T19:28:12.876452Z",
     "iopub.status.idle": "2023-08-18T19:28:12.889267Z",
     "shell.execute_reply": "2023-08-18T19:28:12.888394Z"
    },
    "origin_pos": 17,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def show_Q_function_progress(env_desc, V_all, pi_all): #@save\n",
    "    \"\"\"Visualize a few snapshots of how value and policy change over time.\n",
    "\n",
    "    About ten evenly spaced iterations plus the final one are shown. Each\n",
    "    panel draws the state values as a grayscale grid, the map letter of\n",
    "    each cell, and a red arrow for the chosen action in every cell that is\n",
    "    neither a hole ('H') nor the goal ('G').\n",
    "\n",
    "    Args:\n",
    "        env_desc: 2D array of byte strings describing each grid cell.\n",
    "        V_all: Values with shape [num_iters, num_states].\n",
    "        pi_all: Action indices with shape [num_iters, num_states].\n",
    "    \"\"\"\n",
    "    # We only want to show a few of the iterations\n",
    "    num_iters_all = V_all.shape[0]\n",
    "    if num_iters_all == 0:\n",
    "        return  # Nothing to draw\n",
    "    # With fewer than 10 iterations `num_iters_all // 10` is 0, which is not\n",
    "    # a valid step for np.arange; fall back to showing every iteration.\n",
    "    stride = max(1, num_iters_all // 10)\n",
    "\n",
    "    vis_indx = np.arange(0, num_iters_all, stride).tolist()\n",
    "    # Always include the final iteration, but do not draw it twice\n",
    "    if vis_indx[-1] != num_iters_all - 1:\n",
    "        vis_indx.append(num_iters_all - 1)\n",
    "    V = np.zeros((len(vis_indx), V_all.shape[1]))\n",
    "    pi = np.zeros((len(vis_indx), V_all.shape[1]))\n",
    "\n",
    "    for c, i in enumerate(vis_indx):\n",
    "        V[c] = V_all[i]\n",
    "        pi[c] = pi_all[i]\n",
    "\n",
    "    num_rows, num_cols = env_desc.shape\n",
    "    num_panels = V.shape[0]\n",
    "    # Four panels per row. Keep the 4x4 layout for up to 16 panels and add\n",
    "    # rows beyond that instead of failing inside plt.subplot.\n",
    "    panel_rows = max(4, -(-num_panels // 4))\n",
    "    fig, _ = plt.subplots(figsize=(15, 15 * panel_rows / 4))\n",
    "\n",
    "    # Arrow offsets in image coordinates (imshow draws y growing downwards,\n",
    "    # so UP is a negative dy).\n",
    "    # LEFT action: 0, DOWN action: 1\n",
    "    # RIGHT action: 2, UP action: 3\n",
    "    action2dxdy = {0: (-.25, 0), 1: (0, .25),\n",
    "                   2: (.25, 0), 3: (0, -.25)}\n",
    "    # (color, size) of the letter for special cells; others use ('g', 15)\n",
    "    text_style = {'H': ('y', 20), 'G': ('w', 20)}\n",
    "\n",
    "    for k in range(num_panels):\n",
    "        plt.subplot(panel_rows, 4, k + 1)\n",
    "        plt.imshow(V[k].reshape(num_rows, num_cols), cmap='bone')\n",
    "        ax = plt.gca()\n",
    "        # Minor ticks sit on the cell borders and carry the white grid lines\n",
    "        ax.set_xticks(np.arange(0, num_cols + 1) - .5, minor=True)\n",
    "        ax.set_yticks(np.arange(0, num_rows + 1) - .5, minor=True)\n",
    "        ax.grid(which='minor', color='w', linestyle='-', linewidth=3)\n",
    "        ax.tick_params(which='minor', bottom=False, left=False)\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "\n",
    "        policy = pi[k].reshape(num_rows, num_cols)\n",
    "        for y in range(num_rows):\n",
    "            for x in range(num_cols):\n",
    "                cell = env_desc[y, x].decode()\n",
    "                color, size = text_style.get(cell, ('g', 15))\n",
    "                ax.text(x, y, cell, ha='center', va='center', color=color,\n",
    "                        size=size, fontweight='bold')\n",
    "\n",
    "                # No arrow for cells with G and H labels\n",
    "                if cell not in ('G', 'H'):\n",
    "                    # `pi` was copied into a float array, hence the int()\n",
    "                    dx, dy = action2dxdy[int(policy[y, x])]\n",
    "                    ax.arrow(x, y, dx, dy, color='r', head_width=0.2, head_length=0.15)\n",
    "\n",
    "        ax.set_title('Step = ' + str(vis_indx[k] + 1), fontsize=20)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3af7be4b",
   "metadata": {
    "origin_pos": 18
   },
   "source": [
    "Trainer\n",
    "\n",
    "A bunch of functions that will be deprecated:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1c3317ff",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.892709Z",
     "iopub.status.busy": "2023-08-18T19:28:12.892441Z",
     "iopub.status.idle": "2023-08-18T19:28:12.909614Z",
     "shell.execute_reply": "2023-08-18T19:28:12.908815Z"
    },
    "origin_pos": 20,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def load_array(data_arrays, batch_size, is_train=True):  #@save\n",
    "    \"\"\"Construct a PyTorch data iterator.\"\"\"\n",
    "    dataset = torch.utils.data.TensorDataset(*data_arrays)\n",
    "    return torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train)\n",
    "\n",
    "def synthetic_data(w, b, num_examples):  #@save\n",
    "    \"\"\"Generate y = Xw + b + noise.\"\"\"\n",
    "    X = torch.normal(0, 1, (num_examples, len(w)))\n",
    "    y = torch.matmul(X, w) + b\n",
    "    y += torch.normal(0, 0.01, y.shape)\n",
    "    return X, y.reshape((-1, 1))\n",
    "\n",
    "def sgd(params, lr, batch_size): #@save\n",
    "    \"\"\"Minibatch stochastic gradient descent.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        for param in params:\n",
    "            param -= lr * param.grad / batch_size\n",
    "            param.grad.zero_()\n",
    "\n",
    "def get_dataloader_workers():  #@save\n",
    "    \"\"\"Use 4 processes to read the data.\"\"\"\n",
    "    return 4\n",
    "\n",
    "def load_data_fashion_mnist(batch_size, resize=None):  #@save\n",
    "    \"\"\"Build the training and test data loaders for Fashion-MNIST.\n",
    "\n",
    "    The dataset is stored under '../data' (downloaded when needed). Images\n",
    "    are converted to tensors, preceded by `Resize(resize)` when `resize` is\n",
    "    given. Only the training loader shuffles.\n",
    "    \"\"\"\n",
    "    steps = [transforms.ToTensor()]\n",
    "    if resize:\n",
    "        # Resizing has to happen before the conversion to a tensor\n",
    "        steps = [transforms.Resize(resize)] + steps\n",
    "    pipeline = transforms.Compose(steps)\n",
    "\n",
    "    def fashion_mnist(train):\n",
    "        return torchvision.datasets.FashionMNIST(\n",
    "            root=\"../data\", train=train, transform=pipeline, download=True)\n",
    "\n",
    "    mnist_train = fashion_mnist(True)\n",
    "    mnist_test = fashion_mnist(False)\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        mnist_train, batch_size, shuffle=True,\n",
    "        num_workers=get_dataloader_workers())\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        mnist_test, batch_size, shuffle=False,\n",
    "        num_workers=get_dataloader_workers())\n",
    "    return (train_loader, test_loader)\n",
    "\n",
    "def evaluate_accuracy_gpu(net, data_iter, device=None): #@save\n",
    "    \"\"\"Return the fraction of correctly predicted examples in `data_iter`.\n",
    "\n",
    "    An `nn.Module` is switched to evaluation mode, and when `device` is not\n",
    "    given it defaults to the device of the module's first parameter.\n",
    "    \"\"\"\n",
    "    if isinstance(net, nn.Module):\n",
    "        net.eval()  # Set the model to evaluation mode\n",
    "        device = device or next(iter(net.parameters())).device\n",
    "    # metric[0]: no. of correct predictions, metric[1]: no. of predictions\n",
    "    metric = d2l.Accumulator(2)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for X, y in data_iter:\n",
    "            # A list input is required for BERT Fine-tuning (to be covered later)\n",
    "            X = [x.to(device) for x in X] if isinstance(X, list) else X.to(device)\n",
    "            y = y.to(device)\n",
    "            num_correct = d2l.accuracy(net(X), y)\n",
    "            metric.add(num_correct, y.numel())\n",
    "    correct, total = metric[0], metric[1]\n",
    "    return correct / total\n",
    "\n",
    "\n",
    "#@save\n",
    "def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):\n",
    "    \"\"\"Train a model with a GPU (defined in Chapter 6).\n",
    "\n",
    "    Weights of `nn.Linear` and `nn.Conv2d` layers are re-initialized with\n",
    "    Xavier-uniform, the model is moved to `device` and trained with SGD on\n",
    "    the cross-entropy loss. Training loss and accuracy are plotted about\n",
    "    five times per epoch, test accuracy once per epoch, and a summary is\n",
    "    printed at the end.\n",
    "\n",
    "    Args:\n",
    "        net: The model to train.\n",
    "        train_iter: Iterable of (X, y) training minibatches; must support len().\n",
    "        test_iter: Iterable of (X, y) test minibatches.\n",
    "        num_epochs: Number of passes over `train_iter`.\n",
    "        lr: Learning rate for SGD.\n",
    "        device: Device the model and the minibatches are moved to.\n",
    "    \"\"\"\n",
    "    def init_weights(m):\n",
    "        if type(m) == nn.Linear or type(m) == nn.Conv2d:\n",
    "            nn.init.xavier_uniform_(m.weight)\n",
    "    net.apply(init_weights)\n",
    "    print('training on', device)\n",
    "    net.to(device)\n",
    "    optimizer = torch.optim.SGD(net.parameters(), lr=lr)\n",
    "    loss = nn.CrossEntropyLoss()\n",
    "    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n",
    "                            legend=['train loss', 'train acc', 'test acc'])\n",
    "    timer, num_batches = d2l.Timer(), len(train_iter)\n",
    "    # Plot about five times per epoch. With fewer than five batches\n",
    "    # `num_batches // 5` is 0 and the modulo below would raise\n",
    "    # ZeroDivisionError, so plot after every batch in that case.\n",
    "    plot_every = max(1, num_batches // 5)\n",
    "    for epoch in range(num_epochs):\n",
    "        # Sum of training loss, sum of training accuracy, no. of examples\n",
    "        metric = d2l.Accumulator(3)\n",
    "        net.train()\n",
    "        for i, (X, y) in enumerate(train_iter):\n",
    "            timer.start()\n",
    "            optimizer.zero_grad()\n",
    "            X, y = X.to(device), y.to(device)\n",
    "            y_hat = net(X)\n",
    "            l = loss(y_hat, y)\n",
    "            l.backward()\n",
    "            optimizer.step()\n",
    "            with torch.no_grad():\n",
    "                # The loss is a batch mean; scale it back to a batch sum\n",
    "                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])\n",
    "            timer.stop()\n",
    "            train_l = metric[0] / metric[2]\n",
    "            train_acc = metric[1] / metric[2]\n",
    "            if (i + 1) % plot_every == 0 or i == num_batches - 1:\n",
    "                animator.add(epoch + (i + 1) / num_batches,\n",
    "                             (train_l, train_acc, None))\n",
    "        test_acc = evaluate_accuracy_gpu(net, test_iter)\n",
    "        animator.add(epoch + 1, (None, None, test_acc))\n",
    "    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '\n",
    "          f'test acc {test_acc:.3f}')\n",
    "    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '\n",
    "          f'on {str(device)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2a3c4dcf",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.913071Z",
     "iopub.status.busy": "2023-08-18T19:28:12.912410Z",
     "iopub.status.idle": "2023-08-18T19:28:12.919215Z",
     "shell.execute_reply": "2023-08-18T19:28:12.918183Z"
    },
    "origin_pos": 23,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save\n",
    "    \"\"\"Plot a list of images on a `num_rows` x `num_cols` grid.\n",
    "\n",
    "    Args:\n",
    "        imgs: Images to draw, one per grid cell; extra cells stay empty.\n",
    "        num_rows: Number of grid rows.\n",
    "        num_cols: Number of grid columns.\n",
    "        titles: Optional sequence with one title per image.\n",
    "        scale: Size of each grid cell, used to derive the figure size.\n",
    "\n",
    "    Returns:\n",
    "        The flattened array of axes.\n",
    "    \"\"\"\n",
    "    figsize = (num_cols * scale, num_rows * scale)\n",
    "    # squeeze=False always yields a 2-D array of axes, so flatten() also\n",
    "    # works for a 1x1 grid, where subplots would otherwise return a bare Axes.\n",
    "    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,\n",
    "                               squeeze=False)\n",
    "    axes = axes.flatten()\n",
    "    for i, (ax, img) in enumerate(zip(axes, imgs)):\n",
    "        try:\n",
    "            # Tensor image: convert to NumPy for matplotlib\n",
    "            img = img.detach().numpy()\n",
    "        except Exception:\n",
    "            # Best effort: not a tensor (e.g. a PIL image) or not convertible,\n",
    "            # so hand it to imshow unchanged. `Exception` rather than a bare\n",
    "            # except so KeyboardInterrupt/SystemExit are not swallowed.\n",
    "            pass\n",
    "        ax.imshow(img)\n",
    "        ax.axes.get_xaxis().set_visible(False)\n",
    "        ax.axes.get_yaxis().set_visible(False)\n",
    "        if titles:\n",
    "            ax.set_title(titles[i])\n",
    "    return axes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "251affa5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.922865Z",
     "iopub.status.busy": "2023-08-18T19:28:12.922058Z",
     "iopub.status.idle": "2023-08-18T19:28:12.938713Z",
     "shell.execute_reply": "2023-08-18T19:28:12.937549Z"
    },
    "origin_pos": 24,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def linreg(X, w, b):  #@save\n",
    "    \"\"\"The linear regression model.\"\"\"\n",
    "    return torch.matmul(X, w) + b\n",
    "\n",
    "def squared_loss(y_hat, y):  #@save\n",
    "    \"\"\"Squared loss.\"\"\"\n",
    "    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2\n",
    "\n",
    "def get_fashion_mnist_labels(labels):  #@save\n",
    "    \"\"\"Return text labels for the Fashion-MNIST dataset.\"\"\"\n",
    "    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',\n",
    "                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']\n",
    "    return [text_labels[int(i)] for i in labels]\n",
    "\n",
    "class Animator:  #@save\n",
    "    \"\"\"Incrementally draw several curves on one set of axes.\"\"\"\n",
    "    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,\n",
    "                 ylim=None, xscale='linear', yscale='linear',\n",
    "                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,\n",
    "                 figsize=(3.5, 2.5)):\n",
    "        legend = [] if legend is None else legend\n",
    "        d2l.use_svg_display()\n",
    "        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)\n",
    "        if nrows * ncols == 1:\n",
    "            # Wrap the single Axes so that self.axes[0] is always valid\n",
    "            self.axes = [self.axes, ]\n",
    "\n",
    "        def config_axes():\n",
    "            # Closure that remembers the axis settings given to the constructor\n",
    "            return d2l.set_axes(\n",
    "                self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale,\n",
    "                legend)\n",
    "\n",
    "        self.config_axes = config_axes\n",
    "        self.X, self.Y, self.fmts = None, None, fmts\n",
    "\n",
    "    def add(self, x, y):\n",
    "        \"\"\"Append one point per curve and redraw the figure.\n",
    "\n",
    "        A scalar `y` is treated as a single curve and a scalar `x` is shared\n",
    "        by all curves. Pairs in which either value is None are skipped.\n",
    "        \"\"\"\n",
    "        if not hasattr(y, \"__len__\"):\n",
    "            y = [y]\n",
    "        n = len(y)\n",
    "        if not hasattr(x, \"__len__\"):\n",
    "            x = [x] * n\n",
    "        # Create the per-curve point lists on first use\n",
    "        if not self.X:\n",
    "            self.X = [[] for _ in range(n)]\n",
    "        if not self.Y:\n",
    "            self.Y = [[] for _ in range(n)]\n",
    "        for curve, (x_val, y_val) in enumerate(zip(x, y)):\n",
    "            if x_val is None or y_val is None:\n",
    "                continue\n",
    "            self.X[curve].append(x_val)\n",
    "            self.Y[curve].append(y_val)\n",
    "        axis = self.axes[0]\n",
    "        axis.cla()\n",
    "        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):\n",
    "            axis.plot(xs, ys, fmt)\n",
    "        self.config_axes()\n",
    "        display.display(self.fig)\n",
    "        display.clear_output(wait=True)\n",
    "\n",
    "class Accumulator:  #@save\n",
    "    \"\"\"For accumulating sums over `n` variables.\"\"\"\n",
    "    def __init__(self, n):\n",
    "        self.data = [0.0] * n\n",
    "\n",
    "    def add(self, *args):\n",
    "        self.data = [a + float(b) for a, b in zip(self.data, args)]\n",
    "\n",
    "    def reset(self):\n",
    "        self.data = [0.0] * len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx]\n",
    "\n",
    "\n",
    "def accuracy(y_hat, y):  #@save\n",
    "    \"\"\"Compute the number of correct predictions.\"\"\"\n",
    "    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:\n",
    "        y_hat = y_hat.argmax(axis=1)\n",
    "    cmp = y_hat.type(y.dtype) == y\n",
    "    return float(cmp.type(y.dtype).sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e8525488",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.942440Z",
     "iopub.status.busy": "2023-08-18T19:28:12.941567Z",
     "iopub.status.idle": "2023-08-18T19:28:12.951095Z",
     "shell.execute_reply": "2023-08-18T19:28:12.949973Z"
    },
    "origin_pos": 25,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import hashlib\n",
    "import os\n",
    "import tarfile\n",
    "import zipfile\n",
    "import requests\n",
    "\n",
    "\n",
    "def download(url, folder='../data', sha1_hash=None):  #@save\n",
    "    \"\"\"Download a file to folder and return the local filepath.\n",
    "\n",
    "    Args:\n",
    "        url: An http(s) URL or, for backward compatibility, a key of\n",
    "            `d2l.DATA_HUB` whose value is a (url, sha1_hash) pair.\n",
    "        folder: Directory to store the file in; created when missing.\n",
    "        sha1_hash: Expected SHA-1 hex digest. When it is given and a local\n",
    "            file with that digest already exists, the download is skipped.\n",
    "\n",
    "    Returns:\n",
    "        The path of the local file.\n",
    "\n",
    "    Raises:\n",
    "        requests.HTTPError: If the server answers with an error status.\n",
    "    \"\"\"\n",
    "    if not url.startswith('http'):\n",
    "        # For backward compatibility: `url` is a dataset name\n",
    "        url, sha1_hash = d2l.DATA_HUB[url]\n",
    "    os.makedirs(folder, exist_ok=True)\n",
    "    fname = os.path.join(folder, url.split('/')[-1])\n",
    "    # Check if hit cache\n",
    "    if os.path.exists(fname) and sha1_hash:\n",
    "        sha1 = hashlib.sha1()\n",
    "        with open(fname, 'rb') as f:\n",
    "            # Hash the file in 1 MiB chunks until read() returns b''\n",
    "            for chunk in iter(lambda: f.read(1048576), b''):\n",
    "                sha1.update(chunk)\n",
    "        if sha1.hexdigest() == sha1_hash:\n",
    "            return fname\n",
    "    # Download\n",
    "    print(f'Downloading {fname} from {url}...')\n",
    "    with requests.get(url, stream=True, verify=True) as r:\n",
    "        # Fail loudly instead of saving an HTTP error page as the dataset\n",
    "        r.raise_for_status()\n",
    "        with open(fname, 'wb') as f:\n",
    "            # Write chunk by chunk so the body is never held in memory whole\n",
    "            for chunk in r.iter_content(chunk_size=1048576):\n",
    "                f.write(chunk)\n",
    "    return fname\n",
    "\n",
    "def extract(filename, folder=None):  #@save\n",
    "    \"\"\"Extract a zip/tar file into folder.\"\"\"\n",
    "    base_dir = os.path.dirname(filename)\n",
    "    _, ext = os.path.splitext(filename)\n",
    "    assert ext in ('.zip', '.tar', '.gz'), 'Only support zip/tar files.'\n",
    "    if ext == '.zip':\n",
    "        fp = zipfile.ZipFile(filename, 'r')\n",
    "    else:\n",
    "        fp = tarfile.open(filename, 'r')\n",
    "    if folder is None:\n",
    "        folder = base_dir\n",
    "    fp.extractall(folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9ced27a3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.954444Z",
     "iopub.status.busy": "2023-08-18T19:28:12.953870Z",
     "iopub.status.idle": "2023-08-18T19:28:12.960111Z",
     "shell.execute_reply": "2023-08-18T19:28:12.959248Z"
    },
    "origin_pos": 26,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def download_extract(name, folder=None):  #@save\n",
    "    \"\"\"Download and extract a zip/tar file.\n",
    "\n",
    "    Args:\n",
    "        name: Passed unchanged to `download`.\n",
    "        folder: Optional directory name, relative to the download\n",
    "            directory, to return instead of the default.\n",
    "\n",
    "    Returns:\n",
    "        The download directory joined with `folder` when `folder` is given,\n",
    "        otherwise the path of the downloaded file without its last extension.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If the downloaded file is not a zip/tar archive.\n",
    "    \"\"\"\n",
    "    fname = download(name)\n",
    "    base_dir = os.path.dirname(fname)\n",
    "    data_dir, ext = os.path.splitext(fname)\n",
    "    if ext == '.zip':\n",
    "        opener = zipfile.ZipFile\n",
    "    elif ext in ('.tar', '.gz'):\n",
    "        opener = tarfile.open\n",
    "    else:\n",
    "        # Raised explicitly so the check survives `python -O`\n",
    "        raise AssertionError('Only zip/tar files can be extracted.')\n",
    "    # The context manager closes the archive handle once extraction is done\n",
    "    with opener(fname, 'r') as fp:\n",
    "        fp.extractall(base_dir)\n",
    "    return os.path.join(base_dir, folder) if folder else data_dir\n",
    "\n",
    "\n",
    "def tokenize(lines, token='word'):  #@save\n",
    "    \"\"\"Split text lines into word or character tokens.\"\"\"\n",
    "    assert token in ('word', 'char'), 'Unknown token type: ' + token\n",
    "    return [line.split() if token == 'word' else list(line) for line in lines]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "677b4659",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.963237Z",
     "iopub.status.busy": "2023-08-18T19:28:12.962688Z",
     "iopub.status.idle": "2023-08-18T19:28:12.967574Z",
     "shell.execute_reply": "2023-08-18T19:28:12.966691Z"
    },
    "origin_pos": 27,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def evaluate_loss(net, data_iter, loss):  #@save\n",
    "    \"\"\"Return the mean per-element loss of `net` over `data_iter`.\"\"\"\n",
    "    # metric[0]: sum of losses, metric[1]: no. of examples\n",
    "    metric = d2l.Accumulator(2)\n",
    "    for X, y in data_iter:\n",
    "        out = net(X)\n",
    "        # Align the labels with the shape of the model output\n",
    "        batch_loss = loss(out, y.reshape(out.shape))\n",
    "        metric.add(batch_loss.sum(), batch_loss.numel())\n",
    "    total_loss, num_elements = metric[0], metric[1]\n",
    "    return total_loss / num_elements"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "05e8f2e3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.970822Z",
     "iopub.status.busy": "2023-08-18T19:28:12.970270Z",
     "iopub.status.idle": "2023-08-18T19:28:12.976073Z",
     "shell.execute_reply": "2023-08-18T19:28:12.975000Z"
    },
    "origin_pos": 29,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def grad_clipping(net, theta):  #@save\n",
    "    \"\"\"Clip the gradient.\"\"\"\n",
    "    if isinstance(net, nn.Module):\n",
    "        params = [p for p in net.parameters() if p.requires_grad]\n",
    "    else:\n",
    "        params = net.params\n",
    "    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))\n",
    "    if norm > theta:\n",
    "        for param in params:\n",
    "            param.grad[:] *= theta / norm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63051dcf",
   "metadata": {
    "origin_pos": 31
   },
   "source": [
    "More for the attention chapter.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "205a7cb5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.979520Z",
     "iopub.status.busy": "2023-08-18T19:28:12.979206Z",
     "iopub.status.idle": "2023-08-18T19:28:12.991515Z",
     "shell.execute_reply": "2023-08-18T19:28:12.990689Z"
    },
    "origin_pos": 32,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "# Register the English-French dataset archive under the 'fra-eng' key as a\n",
    "# (download URL, SHA-1 hex digest) pair.\n",
    "#@save\n",
    "d2l.DATA_HUB['fra-eng'] = (d2l.DATA_URL + 'fra-eng.zip',\n",
    "                           '94646ad1522d915e7b0f9296181140edcf86a4f5')\n",
    "\n",
    "#@save\n",
    "def read_data_nmt():\n",
    "    \"\"\"Fetch the 'fra-eng' dataset and return the text of its `fra.txt`.\"\"\"\n",
    "    data_dir = d2l.download_extract('fra-eng')\n",
    "    corpus_path = os.path.join(data_dir, 'fra.txt')\n",
    "    with open(corpus_path, 'r', encoding='utf-8') as f:\n",
    "        raw_text = f.read()\n",
    "    return raw_text\n",
    "\n",
    "#@save\n",
    "def preprocess_nmt(text):\n",
    "    \"\"\"Preprocess the English-French dataset.\"\"\"\n",
    "    def no_space(char, prev_char):\n",
    "        return char in set(',.!?') and prev_char != ' '\n",
    "\n",
    "    # Replace non-breaking space with space, and convert uppercase letters to\n",
    "    # lowercase ones\n",
    "    text = text.replace('\\u202f', ' ').replace('\\xa0', ' ').lower()\n",
    "    # Insert space between words and punctuation marks\n",
    "    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char\n",
    "           for i, char in enumerate(text)]\n",
    "    return ''.join(out)\n",
    "\n",
    "#@save\n",
    "def tokenize_nmt(text, num_examples=None):\n",
    "    \"\"\"Tokenize the English-French dataset.\"\"\"\n",
    "    source, target = [], []\n",
    "    for i, line in enumerate(text.split('\\n')):\n",
    "        if num_examples and i > num_examples:\n",
    "            break\n",
    "        parts = line.split('\\t')\n",
    "        if len(parts) == 2:\n",
    "            source.append(parts[0].split(' '))\n",
    "            target.append(parts[1].split(' '))\n",
    "    return source, target\n",
    "\n",
    "\n",
    "#@save\n",
    "def truncate_pad(line, num_steps, padding_token):\n",
    "    \"\"\"Truncate or pad sequences.\"\"\"\n",
    "    if len(line) > num_steps:\n",
    "        return line[:num_steps]  # Truncate\n",
    "    return line + [padding_token] * (num_steps - len(line))  # Pad\n",
    "\n",
    "\n",
    "#@save\n",
    "def build_array_nmt(lines, vocab, num_steps):\n",
    "    \"\"\"Convert tokenized sentences into a fixed-length index tensor.\n",
    "\n",
    "    Every sentence is mapped through `vocab`, terminated with `<eos>`, and\n",
    "    truncated or padded with `<pad>` to `num_steps`. Returns the resulting\n",
    "    tensor together with the count of non-padding entries in each row.\n",
    "    \"\"\"\n",
    "    rows = []\n",
    "    for tokens in lines:\n",
    "        ids = vocab[tokens] + [vocab['<eos>']]\n",
    "        rows.append(truncate_pad(ids, num_steps, vocab['<pad>']))\n",
    "    array = torch.tensor(rows)\n",
    "    not_pad = array != vocab['<pad>']\n",
    "    valid_len = not_pad.type(torch.int32).sum(1)\n",
    "    return array, valid_len\n",
    "\n",
    "\n",
    "#@save\n",
    "def load_data_nmt(batch_size, num_steps, num_examples=600):\n",
    "    \"\"\"Build the minibatch iterator and both vocabularies for translation.\n",
    "\n",
    "    Returns `(data_iter, src_vocab, tgt_vocab)`.\n",
    "    \"\"\"\n",
    "    def make_vocab(sentences):\n",
    "        # Source and target vocabularies share the same settings\n",
    "        return d2l.Vocab(sentences, min_freq=2,\n",
    "                         reserved_tokens=['<pad>', '<bos>', '<eos>'])\n",
    "\n",
    "    cleaned = preprocess_nmt(read_data_nmt())\n",
    "    source, target = tokenize_nmt(cleaned, num_examples)\n",
    "    src_vocab, tgt_vocab = make_vocab(source), make_vocab(target)\n",
    "    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)\n",
    "    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)\n",
    "    data_iter = d2l.load_array(\n",
    "        (src_array, src_valid_len, tgt_array, tgt_valid_len), batch_size)\n",
    "    return data_iter, src_vocab, tgt_vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c9ab814a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:28:12.994838Z",
     "iopub.status.busy": "2023-08-18T19:28:12.994441Z",
     "iopub.status.idle": "2023-08-18T19:28:13.011201Z",
     "shell.execute_reply": "2023-08-18T19:28:13.010391Z"
    },
    "origin_pos": 34,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "def sequence_mask(X, valid_len, value=0):\n",
    "    \"\"\"Mask irrelevant entries in sequences.\"\"\"\n",
    "    maxlen = X.size(1)\n",
    "    mask = torch.arange((maxlen), dtype=torch.float32,\n",
    "                        device=X.device)[None, :] < valid_len[:, None]\n",
    "    X[~mask] = value\n",
    "    return X\n",
    "\n",
    "\n",
    "#@save\n",
    "class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):\n",
    "    \"\"\"The softmax cross-entropy loss with padded positions zeroed out.\"\"\"\n",
    "    def forward(self, pred, label, valid_len):\n",
    "        \"\"\"Return the masked loss averaged over the time-step axis.\n",
    "\n",
    "        Shapes: `pred` is (`batch_size`, `num_steps`, `vocab_size`), `label`\n",
    "        is (`batch_size`, `num_steps`) and `valid_len` is (`batch_size`,).\n",
    "        \"\"\"\n",
    "        # One weight per label, set to zero beyond the valid length\n",
    "        mask = sequence_mask(torch.ones_like(label), valid_len)\n",
    "        # Keep the per-token losses so that the mask can be applied\n",
    "        self.reduction = 'none'\n",
    "        # The parent loss is given the class axis as the second axis\n",
    "        per_token = super().forward(pred.permute(0, 2, 1), label)\n",
    "        return (per_token * mask).mean(dim=1)\n",
    "\n",
    "#@save\n",
    "def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):\n",
    "    \"\"\"Fit a sequence-to-sequence model using teacher forcing.\n",
    "\n",
    "    Weights of `nn.Linear` and `nn.GRU` layers are Xavier-initialized, the\n",
    "    model is optimized with Adam on the masked cross-entropy loss with\n",
    "    gradients clipped at 1, the loss per token is plotted every 10 epochs,\n",
    "    and a summary of the last epoch is printed at the end.\n",
    "    \"\"\"\n",
    "    def init_xavier(module):\n",
    "        # Exact type checks: subclasses of these layers are not touched\n",
    "        if type(module) is nn.Linear:\n",
    "            nn.init.xavier_uniform_(module.weight)\n",
    "        elif type(module) is nn.GRU:\n",
    "            for name in module._flat_weights_names:\n",
    "                if \"weight\" in name:\n",
    "                    nn.init.xavier_uniform_(module._parameters[name])\n",
    "\n",
    "    net.apply(init_xavier)\n",
    "    net.to(device)\n",
    "    optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "    loss = MaskedSoftmaxCELoss()\n",
    "    net.train()\n",
    "    animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n",
    "                            xlim=[10, num_epochs])\n",
    "    for epoch in range(num_epochs):\n",
    "        timer = d2l.Timer()\n",
    "        # Running totals: summed training loss, number of target tokens\n",
    "        metric = d2l.Accumulator(2)\n",
    "        for batch in data_iter:\n",
    "            optimizer.zero_grad()\n",
    "            X, X_valid_len, Y, Y_valid_len = [t.to(device) for t in batch]\n",
    "            # Teacher forcing: feed `<bos>` followed by the shifted labels\n",
    "            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],\n",
    "                               device=device).reshape(-1, 1)\n",
    "            dec_input = torch.cat([bos, Y[:, :-1]], 1)\n",
    "            Y_hat, _ = net(X, dec_input, X_valid_len)\n",
    "            batch_loss = loss(Y_hat, Y, Y_valid_len)\n",
    "            # `backward` needs a scalar, so sum over the batch\n",
    "            batch_loss.sum().backward()\n",
    "            d2l.grad_clipping(net, 1)\n",
    "            num_tokens = Y_valid_len.sum()\n",
    "            optimizer.step()\n",
    "            with torch.no_grad():\n",
    "                metric.add(batch_loss.sum(), num_tokens)\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "            animator.add(epoch + 1, (metric[0] / metric[1],))\n",
    "    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '\n",
    "          f'tokens/sec on {str(device)}')\n",
    "\n",
    "\n",
    "#@save\n",
    "def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,\n",
    "                    device, save_attention_weights=False):\n",
    "    \"\"\"Predict for sequence to sequence.\n",
    "\n",
    "    Greedily decodes a translation of `src_sentence`, producing at most\n",
    "    `num_steps` tokens and stopping early once `<eos>` is predicted.\n",
    "\n",
    "    Returns the predicted sentence as a space-joined string and the list of\n",
    "    `net.decoder.attention_weights` collected at each step (empty unless\n",
    "    `save_attention_weights` is set).\n",
    "    \"\"\"\n",
    "    # Set `net` to eval mode for inference\n",
    "    net.eval()\n",
    "    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [\n",
    "        src_vocab['<eos>']]\n",
    "    # The source is cut to `num_steps` right below, so the valid length\n",
    "    # must not exceed it\n",
    "    enc_valid_len = torch.tensor([min(len(src_tokens), num_steps)],\n",
    "                                 device=device)\n",
    "    src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])\n",
    "    # Add the batch axis\n",
    "    enc_X = torch.unsqueeze(\n",
    "        torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)\n",
    "    output_seq, attention_weight_seq = [], []\n",
    "    # Inference needs no gradients, so skip building the autograd graph\n",
    "    with torch.no_grad():\n",
    "        enc_outputs = net.encoder(enc_X, enc_valid_len)\n",
    "        dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)\n",
    "        # Add the batch axis\n",
    "        dec_X = torch.unsqueeze(torch.tensor(\n",
    "            [tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)\n",
    "        for _ in range(num_steps):\n",
    "            Y, dec_state = net.decoder(dec_X, dec_state)\n",
    "            # We use the token with the highest prediction likelihood as\n",
    "            # input of the decoder at the next time step\n",
    "            dec_X = Y.argmax(dim=2)\n",
    "            pred = dec_X.squeeze(dim=0).type(torch.int32).item()\n",
    "            # Save attention weights (to be covered later)\n",
    "            if save_attention_weights:\n",
    "                attention_weight_seq.append(net.decoder.attention_weights)\n",
    "            # Once the end-of-sequence token is predicted, the generation\n",
    "            # of the output sequence is complete\n",
    "            if pred == tgt_vocab['<eos>']:\n",
    "                break\n",
    "            output_seq.append(pred)\n",
    "    return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}